import numpy as np
import pylab as pl
import matplotlib_defaults
import sys, getopt
from scipy.optimize import leastsq

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# This script plots the results of the differential-changes experiment and
# provides an exponential fit to the relative conductance changes under
# LTP and LTD.
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #


# # # # # # # # # # # #
# # # P A R A M S # # #
# # # # # # # # # # # #

# Set filename
path = "data/"
default_fname = "TR75.npz"   # used when the prompt below is answered with an empty line

# User-defined input. raw_input exists only on Python 2 (where `input` would
# eval the answer), so pick whichever plain line reader is available.
try:
    _read_line = raw_input
except NameError:
    _read_line = input
fname = _read_line("Enter filename: ").strip() or default_fname

# params
# Legacy parameter set, kept for reference:
#smooth_kernel = [0.05]*20
#delta_pulse = 10
#ltp_idx = (range(0,300), range(449,750))
#ltd_idx = (range(299,450), range(749,1200))

smoothing_kern_size = 20   # overall size of the moving-average smoothing kernel

delta_pulse = 100          # trial difference for the 'relative change' plot (right)
smooth_kernel = [1.0 / smoothing_kern_size] * smoothing_kern_size   # kernel for smoothing the data (blue line, left)
# Number of samples at each edge corrupted by the smoothing kernel. Floor division
# keeps this an int on Python 3 too, where `/` would give an unusable float slice bound.
discard = len(smooth_kernel) // 2 + 1

# Pulse windows are expressed in repetition blocks of `pulses_per_rep` pulses
# (ProbNo*PulseNo); the multipliers give the start/end repetition of each window.
pulses_per_rep = 120
ltp_idx = (range(pulses_per_rep * 20, pulses_per_rep * 40), )                                # list of ranges of LTP pulses
ltd_idx = (range(0, pulses_per_rep * 20), range(pulses_per_rep * 40, pulses_per_rep * 60))   # list of ranges of LTD pulses
ntr_idx = (range(pulses_per_rep * 60, pulses_per_rep * 80), )                                # list of ranges for neutral pulses

# # # # # # # # # # # # # # # # # # #
# # # E N D   O F   P A R A M S # # #
# # # # # # # # # # # # # # # # # # #


# Load data and add every array of the .npz archive to the module namespace
# (the script below reads e.g. `write_resist` and `param` as globals).
# `param` is read via param.tolist() and then param['S'], so it is presumably a
# 0-d object array wrapping a dict. Object arrays are pickled, which NumPy >= 1.16.3
# refuses to load unless allow_pickle=True.
# NOTE(review): pickled data can execute code on load - only open trusted files.
# The context manager closes the archive once all arrays have been read.
with np.load(path + fname, allow_pickle=True) as X:
    for k in X.files:
        globals()[k] = X[k]
param = param.tolist()
Smax = len(param['S'])   # number of entries in param['S']; not referenced elsewhere in this file

# prepare data
# (Alternative signal: mean read conductance per trial, (1./read_resist).mean(1).)
# G: reciprocal of every write_resist entry, flattened into one 1-D trace that
# the LTP / LTD / neutral index ranges refer to.
G = np.divide(1., write_resist).flatten()
# Moving average of G (blue line, left plot); 'same' mode keeps the length of G.
gsmooth = np.convolve(G, smooth_kernel, 'same')

# Relative changes and their baseline conductances, collected separately for
# potentiation (_p) and depression (_d).
dgg_p, g0_p = [], []
dgg_d, g0_d = [], []
y = [0]   # placeholder; the fitting code rebinds y to the data being fitted

# Initial fit parameters [sign, g in uS of 1% change] for the relative-change fits.
mean_g = np.mean(G)
fit_param_init_ltp = [1., mean_g * (10**6)]
fit_param_init_ltd = [-1., mean_g * (10**6)]
print(mean_g)

#Prepare functions and initial guesses for fittings.
def f(p, x):
    """Exponential fitting model g(x) = p[0] * exp(-p[1] * x) + p[2].

    p: parameter vector [a, b, c]; x: x-data (absolute trial index).
    """
    return p[0] * np.exp(-p[1] * x) + p[2]
# Alternative models (each needs a matching fitguess below):
#f = lambda p, x: p[0] * (x - p[1])**2 + p[2]*(x - p[1]) + p[3] #Poly-2 fit.
#f = lambda p, x: p[0]*np.log(x - p[1]) + p[2] # Log fit.
#f = lambda p, x: p[0]*np.tanh((x - p[1])*p[3]) + p[2] # tanh fit.


def h(p, x):
    """Residual error y - f(p, x) for leastsq.

    NOTE: reads the module-level `y` (the y-data being fitted); it must be
    rebound to the current data before each leastsq call.
    """
    return y - f(p, x)


# Initial guess [a, b, c] for the exponential model above, with exactly one entry
# per model parameter. A surplus entry would never be used by f, which gives
# leastsq an all-zero Jacobian column (rank-deficient, singular covariance).
fitguess = np.array([0.0, 0.001, 0.0004])
#fitguess = np.array([0.0, 2400, 0.0, 0.0005]) #Array of initial guess parameters for poly2 fit.
#fitguess = np.array([1.1, 0.0, 0.0005]) #Array of initial guess parameters for log fit.
#fitguess = np.array([-0.000015, 4700.0, 0.0003, 1.0]) #Array of initial guess parameters for tanh fit.

#Prepare variables to store fittings.
# One slot per LTD range: the LTD fitting loop stores its i-th fit in slot i,
# so the lists are sized from ltd_idx rather than a hard-coded 2.
fitltd = [0] * len(ltd_idx)
fitltdx = [0] * len(ltd_idx)
fitltdy = [0] * len(ltd_idx)

def _relative_change(g_seg, lag):
    """Return ``(dgg, g0)`` for one conductance segment.

    g0 = g_seg[:-lag] is the baseline g(n) and dgg = (g(n+lag) - g(n)) / g(n)
    is the relative change over `lag` pulses; both have len(g_seg) - lag entries.
    Raises ValueError for lag <= 0 (g_seg[:-0] would silently be empty).
    """
    if lag <= 0:
        raise ValueError("lag must be positive, got %r" % (lag,))
    g0 = g_seg[:-lag]
    return (g_seg[lag:] - g0) / g0, g0


# calc data points for relative changes - also compute fittings.
# NOTE: the residual function `h` reads the module-level `y`, so `y` is rebound
# to the segment being fitted right before every leastsq call.
# Variants: use gsmooth[idx] instead of G[idx] for the relative change to work on
# the smoothed trace; narrow each range with idx = idx[discard:-discard] to drop
# samples corrupted by the smoothing kernel.
for idx in ltp_idx:
    g_seg = G[idx]
    dgg, g0 = _relative_change(g_seg, delta_pulse)
    dgg_p += list(dgg)
    g0_p += list(g0)   # absolute (baseline) conductance

    y = g_seg
    fitltp = leastsq(h, fitguess, args=(np.array(idx),))[0]
    fitltpx = np.array(idx)
    fitltpy = g_seg
    print('LTP fitting params are: %s' % (fitltp,))

for i, idx in enumerate(ltd_idx):
    g_seg = G[idx]
    dgg, g0 = _relative_change(g_seg, delta_pulse)
    dgg_d += list(dgg)
    g0_d += list(g0)   # absolute (baseline) conductance

    y = g_seg
    fitltd[i] = leastsq(h, fitguess, args=(np.array(idx),))[0]
    fitltdx[i] = np.array(idx)
    fitltdy[i] = g_seg
    print('Fitting params for %d are: %s' % (i, fitltd[i]))


for idx in ntr_idx:
    g_seg = G[idx]
    y = g_seg
    fitntr = leastsq(h, fitguess, args=(np.array(idx),))[0]
    fitntrx = np.array(idx)
    fitntry = g_seg
    print('Fit neutral: %s' % (fitntr,))
    # Fitted conductance at the first and last trial of the window, in uS.
    print('Initial and final fit values (uS): %s %s'
          % (f(fitntr, idx[0]) * (10**6), f(fitntr, idx[-1]) * (10**6)))
    print('STD in neutral section of TR (not residue vs. fit): %s' % (np.std(g_seg),))
    print('Normalised STD %%: %s' % (100 * np.std(g_seg) / np.mean(g_seg),))
    histnoise = np.histogram(g_seg)
    print(histnoise)

    # Save neutral region data to file for further processing
    # (overwritten for every neutral range; the last range wins).
    with open('workfile.txt', 'w') as noiselog:
        noiselog.writelines("%s\n" % item for item in fitntry)





# # # # # # # # # # #
# # # FIGURE 1  # # #
# # # # # # # # # # #

# Prepare Figure and Axes for plotting.
pl.rc("font", size=8)      # small font for a compact two-panel figure
pl.rc("figure", dpi=110)
pl.interactive(False)      # build all figures first, display them later
fig = pl.figure(figsize=(7, 2.5))
# Axes rectangles in figure fractions (left, bottom, width, height): both panels
# share bottom/width/height and differ only in their left edge.
_panel_geom = (0.17, 0.38, 0.75)                  # (bottom, width, height)
ax_smooth = fig.add_axes((0.08,) + _panel_geom)   # left: raw + smoothed conductance
ax_dgg = fig.add_axes((0.60,) + _panel_geom)      # right: relative change vs conductance

def _shade_regions(ax, ranges, color, label, label_color, dy):
    """Shade every trial range in `ranges` over the full y-extent of `ax` and label it.

    `ranges` holds absolute trial indices (the same x coordinates as the traces).
    The label is centred on each range, 0.1*dy above the lower y-limit.
    """
    y_lo, y_hi = ax.get_ylim()
    for rng in ranges:
        idx = np.array(rng)
        x_lo, x_hi = idx[0], idx[-1]
        ax.add_patch(pl.Rectangle((x_lo, y_lo), x_hi - x_lo, y_hi - y_lo,
                                  lw=0, fc=color, ec=color, alpha=0.4))
        ax.text(idx.mean(), y_lo + 0.1 * dy, label, color=label_color,
                weight="semibold", ha='center', va='baseline')


# PLOT: Smooth plot (left)
ax = ax_smooth
# Samples at each edge corrupted by the 'same'-mode smoothing; floor division
# keeps this an int (usable as a slice bound) on Python 3 as well.
discard = len(smooth_kernel) // 2 + 1
# x is the true trial number of each plotted sample. The fitted curves are
# evaluated at absolute trial indices and the shaded regions use the same
# indices, so traces, fits and regions share one x coordinate system.
x = np.arange(discard, len(gsmooth) - discard)
y1 = G[discard:-discard] * 10**6            # raw data in uS
y2 = gsmooth[discard:-discard] * 10**6      # smoothed data in uS

ax.plot(x, y1, '0.2', lw=0.5)
ax.plot(x, y2, 'b', lw=1.)
# fmt axes
ax.set_xlabel("Trial number")
ax.set_xlim(x.min(), x.max())
ax.set_ylabel("Conductance $g(n)$ [uS]")
ymin, ymax = y1.min(), y1.max()
dy = ymax - ymin
ax.set_ylim(ymin - 0.1*dy, ymax + 0.1*dy)

# Add fittings (model output is in S; convert to uS).
ax.plot(fitltpx, f(fitltp, fitltpx) * 10**6, 'r', lw=1.)
for i in range(len(ltd_idx)):
    ax.plot(fitltdx[i], f(fitltd[i], fitltdx[i]) * 10**6, 'r', lw=1.)
ax.plot(fitntrx, f(fitntr, fitntrx) * 10**6, 'r', lw=1.)

# rectangles marking the LTP and LTD regions
_shade_regions(ax, ltp_idx, "tomato", "LTP", 'red', dy)
_shade_regions(ax, ltd_idx, "lightblue", "LTD", 'blue', dy)


# PLOT: delta g over g_0 (right)
ax = ax_dgg

# Model for the relative change (as a fraction) against baseline conductance x [uS]:
#   h(p, x) = sign(p[0]) * 0.01 * exp(-p[0] * (x - p[1]))
# At x = p[1] the change is exactly +/-1 %, and the sign of p[0] sets the sign of
# the change. Kept under the module-level name `h` because Figure 2 evaluates it.
h = lambda p, x: np.sign(p[0]) * 0.01 * np.exp(-p[0] * (x - p[1]))


def _fit_relative_change(ax, g0, dgg, p_init, point_color, line_color):
    """Scatter dgg [%] against g0 [uS] on `ax`, fit `h` and draw the fitted curve.

    g0: baseline conductances [S]; dgg: relative changes (fractions);
    p_init: initial [p0, p1] for leastsq.
    Returns (fitted params, x data in uS, y data in %).
    """
    xs = np.array(g0) * 10**6
    ys = np.array(dgg) * 100
    ax.plot(xs, ys, '.', c=point_color, ms=1.5)
    # Residual: data (converted back to fractions) minus model; the y-data is
    # passed explicitly instead of being read from a global.
    residual = lambda p, xd, yd: yd / 100. - h(p, xd)
    params = leastsq(residual, np.array(p_init), args=(xs, ys))[0]
    xfit = np.linspace(xs.min(), xs.max(), 100)
    ax.plot(xfit, 100 * h(params, xfit), line_color, lw=1.)
    return params, xs, ys


# LTP PLOT
pp, x_ltp, y_ltp = _fit_relative_change(ax, g0_p, dgg_p, fit_param_init_ltp, 'tomato', 'r')
print(" > LTP fit params: %s" % (pp,))

# LTD PLOT
pd, x_ltd, y_ltd = _fit_relative_change(ax, g0_d, dgg_d, fit_param_init_ltd, 'lightblue', 'b')
print(" > LTD fit params: %s" % (pd,))

# fmt axes (limits span both data sets)
xmin, xmax = min(x_ltp.min(), x_ltd.min()), max(x_ltp.max(), x_ltd.max())
ymin, ymax = min(y_ltp.min(), y_ltd.min()), max(y_ltp.max(), y_ltd.max())
ax.set_xlabel("Conductance $g(n)$ [uS]")
ax.set_xlim(xmin, xmax)
ax.set_ylabel(r"Relative change $\frac{g(n+%d) - g(n)}{g(n)}$" % delta_pulse + " [%]")
dy = ymax - ymin
ax.set_ylim(ymin - 0.1*dy, ymax + 0.1*dy)

# zero line
x1, x2 = ax.get_xlim()
ax.plot((x1, x2), (0, 0), '--', c='0.2', lw=1.0)


# # # # # # # # # # #
# # # FIGURE 2  # # #
# # # # # # # # # # #

# Figure f(g)
# Plasticity function built from the two relative-change fits (pp = LTP, pd = LTD):
#   f(g) = -h(pd, g) / (h(pp, g) - h(pd, g)),
# evaluated on 100 points spanning all baseline conductances [uS].
fig2 = pl.figure(figsize=(3.5, 2.5))
xmin = min(min(g0_d), min(g0_p)) * 10**6
xmax = max(max(g0_d), max(g0_p)) * 10**6
x = np.linspace(xmin, xmax, 100)
h_ltp = h(pp, x)
h_ltd = h(pd, x)
y = -h_ltd / (h_ltp - h_ltd)

pl.plot(x, y, 'r', lw=1.)
pl.xlabel("Conductance $g$ [uS]")
pl.ylabel("Plasticity function $f(g)$")
pl.xlim(xmin, xmax)
# 10 % vertical margin around the curve.
ymin = y.min()
ymax = y.max()
dy = ymax - ymin
pl.ylim(ymin - 0.1 * dy, ymax + 0.1 * dy)
pl.subplots_adjust(left=0.15, bottom=0.17, right=0.95, top=0.95)

# Display all figures. With interactive mode on, show() does not block, so the
# prompt at the end keeps the process (and its windows) alive.
pl.interactive(True)
pl.show()



#def test_fig(ax):
    #ax2 = ax.twinx()
    #xmin, xmax = ax2.get_xlim()
    #y = 1. / write_resist.flatten() * 10**6
    #x = np.linspace(xmin, xmax, len(y))
    #ax2.plot(x, y, ':r', lw=0.25)
    #ax2.set_xlim(xmin, xmax)
    #ax2.set_ylim(*ax.get_ylim())
    #return ax2

# plot function
# NOT USED IN THE CURRENT SCRIPT
#def draw_subplot(s,ax):
    #idx = (param['S'] == s).nonzero()[0][0]
    #fmt = lambda dt,v: ( dt=="int" and ("%d" % v) or ("%.2f" % v) )
    #string = [ fmt(dt,param[k][idx]) for k,dt in zip(param_name,param_dtype)]
    #string = "(" + ", ".join(string) + ")"
    #ax.set_title(string, fontsize=7)
    ## conductances
    #x = np.arange(num_write)
    #y = 1. / write_resist[idx]
    #ax.plot(x, y, 'b', lw=0.5)
    #ymin,ymax = y.min(),y.max()
    #dy = ymax - ymin
    #ax.set_ylim(ymin-0.1*dy, ymax+0.1*dy)
    ## pulses
    #x = (write_pulse[idx] == 1).nonzero()[0]
    #y = [ymin] * len(x)
    #ax.plot(x, y, 'r.', ms=1)
    #ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
    #ax.set_xlim(0,num_write)

# Wait for ENTER. raw_input exists only on Python 2 (where `input` would eval the
# line), so fall back to Python 3's input.
try:
    _wait_for_enter = raw_input
except NameError:
    _wait_for_enter = input
_wait_for_enter('Press ENTER to exit')